/**
 * Copyright (c) 2024 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the
 * "License"). Please refer to the License for details. You may not use this
 * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN
 * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS
 * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository
 * for the full text of the License.
 */

#ifndef EXAMPLES_NORMALIZATION_WELFORDUPDATE_CUSTOM_H
#define EXAMPLES_NORMALIZATION_WELFORDUPDATE_CUSTOM_H
#include "kernel_operator.h"

namespace MyCustomKernel {
// Tiling parameters passed by value to KernelWelfordUpdate::Init, which copies
// each field into the kernel object. Field order is part of the interface for
// any caller that aggregate-initializes this struct.
struct VecTiling {
    // Chooses at run time between the WelfordUpdateConfig{true} and
    // WelfordUpdateConfig{false} instantiations of AscendC::WelfordUpdate.
    bool inplace;
    // Forwarded as the 1st field of AscendC::WelfordUpdateParam. The product
    // nLength * rLength is the element count used for every global/local buffer.
    uint32_t nLength;
    // Forwarded as the 2nd field of AscendC::WelfordUpdateParam.
    uint32_t rLength;
    // Forwarded as the 3rd field of AscendC::WelfordUpdateParam.
    uint32_t abComputeLength;
    // Forwarded as the 4th field of AscendC::WelfordUpdateParam.
    // NOTE(review): presumably a precomputed reciprocal of a count - confirm with the host tiling code.
    float nRec;
    // Size in bytes of the temporary buffer handed to WelfordUpdate; only used
    // when the kernel's tmpLocal template argument is true. The kernel rounds it
    // up to a multiple of LOCAL_BYTES before reserving the buffer.
    uint32_t tmpLocalSize;
};

// Granularity, in bytes, to which the temporary-buffer size is rounded up
// before it is handed to pipe.InitBuffer.
constexpr uint8_t LOCAL_BYTES{32};

// Single-shot kernel wrapper around AscendC::WelfordUpdate.
//
// Template parameters:
//   T             - element type of the inputX buffer.
//   U             - element type of the mean/variance buffers (inputs and outputs).
//   isReuseSource - forwarded unchanged as the 3rd template argument of WelfordUpdate.
//   tmpLocal      - true:  the kernel reserves a temporary buffer of
//                          VecTiling::tmpLocalSize bytes (rounded up to LOCAL_BYTES)
//                          and passes it to WelfordUpdate;
//                   false: the WelfordUpdate overload without a temporary buffer is used.
//
// Usage: call Init() once, then Process(). All local buffers are reserved in
// Init(), so Process() itself performs no buffer reservation.
template <typename T, typename U, bool isReuseSource = false, bool tmpLocal = true>
class KernelWelfordUpdate {
public:
    __aicore__ inline KernelWelfordUpdate() {}

    // Binds the five global-memory addresses and reserves every local buffer.
    //
    // inputX_gm                   - source of T elements.
    // inputMean_gm / inputVar_gm  - sources of U elements.
    // outputMean_gm / outputVar_gm- destinations of U elements.
    // tilingData                  - sizes and flags; see VecTiling.
    //
    // Every buffer holds nLength * rLength elements.
    __aicore__ inline void Init(GM_ADDR inputX_gm, GM_ADDR inputMean_gm, GM_ADDR inputVar_gm, GM_ADDR outputMean_gm,
        GM_ADDR outputVar_gm, VecTiling tilingData) {
        nLength = tilingData.nLength;
        rLength = tilingData.rLength;
        abComputeLength = tilingData.abComputeLength;
        nRec = tilingData.nRec;
        bshLength = tilingData.nLength * tilingData.rLength;
        inplace = tilingData.inplace;
        tmpLocalBytes = tilingData.tmpLocalSize;

        inputX_global.SetGlobalBuffer(reinterpret_cast<__gm__ T *>(inputX_gm), bshLength);
        inputMean_global.SetGlobalBuffer(reinterpret_cast<__gm__ U *>(inputMean_gm), bshLength);
        inputVar_global.SetGlobalBuffer(reinterpret_cast<__gm__ U *>(inputVar_gm), bshLength);

        outputMean_global.SetGlobalBuffer(reinterpret_cast<__gm__ U *>(outputMean_gm), bshLength);
        outputVar_global.SetGlobalBuffer(reinterpret_cast<__gm__ U *>(outputVar_gm), bshLength);

        pipe.InitBuffer(inQueueX, 1, sizeof(T) * bshLength);
        pipe.InitBuffer(inQueueMean, 1, sizeof(U) * bshLength);
        pipe.InitBuffer(inQueueVar, 1, sizeof(U) * bshLength);
        pipe.InitBuffer(outQueueMean, 1, sizeof(U) * bshLength);
        pipe.InitBuffer(outQueueVar, 1, sizeof(U) * bshLength);

        // The scratch buffer is reserved here, after the queues, rather than in
        // Compute(): reserving it per Compute() call would allocate it again on
        // every Process() invocation.
        if constexpr (tmpLocal) {
            // Round up to the next multiple of LOCAL_BYTES (no-op when already aligned).
            tmpLocalBytes = (tmpLocalBytes + LOCAL_BYTES - 1) / LOCAL_BYTES * LOCAL_BYTES;
            pipe.InitBuffer(tmpLocalBuf, tmpLocalBytes);
        }
    }

    // Runs the whole pipeline once: load inputs, update mean/variance, store outputs.
    __aicore__ inline void Process() {
        CopyIn();
        Compute();
        CopyOut();
    }

private:
    // Copies inputX, inputMean and inputVar from global memory into freshly
    // allocated queue tensors and enqueues them for Compute().
    __aicore__ inline void CopyIn() {
        AscendC::LocalTensor<T> inputXLocal = inQueueX.AllocTensor<T>();
        AscendC::LocalTensor<U> inMeanLocal = inQueueMean.AllocTensor<U>();
        AscendC::LocalTensor<U> inVarLocal = inQueueVar.AllocTensor<U>();

        AscendC::DataCopy(inputXLocal, inputX_global, bshLength);
        AscendC::DataCopy(inMeanLocal, inputMean_global, bshLength);
        AscendC::DataCopy(inVarLocal, inputVar_global, bshLength);

        inQueueX.EnQue(inputXLocal);
        inQueueMean.EnQue(inMeanLocal);
        inQueueVar.EnQue(inVarLocal);
    }

    // Dequeues the three inputs, calls AscendC::WelfordUpdate into two newly
    // allocated output tensors, releases the inputs and enqueues the outputs.
    __aicore__ inline void Compute() {
        AscendC::LocalTensor<T> inputXLocal = inQueueX.DeQue<T>();
        AscendC::LocalTensor<U> inMeanLocal = inQueueMean.DeQue<U>();
        AscendC::LocalTensor<U> inVarLocal = inQueueVar.DeQue<U>();

        AscendC::LocalTensor<U> outMeanLocal = outQueueMean.AllocTensor<U>();
        AscendC::LocalTensor<U> outVarLocal = outQueueVar.AllocTensor<U>();

        // The config is a template argument of WelfordUpdate, so the run-time
        // `inplace` flag has to pick between two separate instantiations.
        static constexpr AscendC::WelfordUpdateConfig WELFORD_UPDATE_ENABLE_INPLACE_CFG = {true};
        static constexpr AscendC::WelfordUpdateConfig WELFORD_UPDATE_UNENABLE_INPLACE_CFG = {false};
        AscendC::WelfordUpdateParam para = {nLength, rLength, abComputeLength, nRec};

        // tmpLocal is a template parameter: only the selected overload is instantiated.
        if constexpr (tmpLocal) {
            // Scratch space reserved in Init(); passed as the extra tensor argument.
            AscendC::LocalTensor<uint8_t> tmpLocalTensor = tmpLocalBuf.Get<uint8_t>();
            if (inplace) {
                AscendC::WelfordUpdate<T, U, isReuseSource, WELFORD_UPDATE_ENABLE_INPLACE_CFG>(outMeanLocal,
                    outVarLocal, inMeanLocal, inVarLocal, inputXLocal, tmpLocalTensor, para);
            } else {
                AscendC::WelfordUpdate<T, U, isReuseSource, WELFORD_UPDATE_UNENABLE_INPLACE_CFG>(outMeanLocal,
                    outVarLocal, inMeanLocal, inVarLocal, inputXLocal, tmpLocalTensor, para);
            }
        } else {
            if (inplace) {
                AscendC::WelfordUpdate<T, U, isReuseSource, WELFORD_UPDATE_ENABLE_INPLACE_CFG>(outMeanLocal,
                    outVarLocal, inMeanLocal, inVarLocal, inputXLocal, para);
            } else {
                AscendC::WelfordUpdate<T, U, isReuseSource, WELFORD_UPDATE_UNENABLE_INPLACE_CFG>(outMeanLocal,
                    outVarLocal, inMeanLocal, inVarLocal, inputXLocal, para);
            }
        }

        inQueueX.FreeTensor(inputXLocal);
        inQueueMean.FreeTensor(inMeanLocal);
        inQueueVar.FreeTensor(inVarLocal);

        outQueueMean.EnQue(outMeanLocal);
        outQueueVar.EnQue(outVarLocal);
    }

    // Dequeues the two results, copies them to global memory and returns the
    // tensors to their queues.
    __aicore__ inline void CopyOut() {
        AscendC::LocalTensor<U> outMeanLocal = outQueueMean.DeQue<U>();
        AscendC::LocalTensor<U> outVarLocal = outQueueVar.DeQue<U>();

        AscendC::DataCopy(outputMean_global, outMeanLocal, bshLength);
        AscendC::DataCopy(outputVar_global, outVarLocal, bshLength);

        outQueueMean.FreeTensor(outMeanLocal);
        outQueueVar.FreeTensor(outVarLocal);
    }

private:
    AscendC::TPipe pipe;
    // One single-slot queue per input / output tensor.
    AscendC::TQue<AscendC::TPosition::VECIN, 1> inQueueX;
    AscendC::TQue<AscendC::TPosition::VECIN, 1> inQueueMean;
    AscendC::TQue<AscendC::TPosition::VECIN, 1> inQueueVar;
    AscendC::TQue<AscendC::TPosition::VECOUT, 1> outQueueMean;
    AscendC::TQue<AscendC::TPosition::VECOUT, 1> outQueueVar;

    AscendC::GlobalTensor<T> inputX_global;
    AscendC::GlobalTensor<U> inputMean_global;
    AscendC::GlobalTensor<U> inputVar_global;
    AscendC::GlobalTensor<U> outputMean_global;
    AscendC::GlobalTensor<U> outputVar_global;
    // Scratch buffer for WelfordUpdate; only initialised when tmpLocal is true.
    AscendC::TBuf<AscendC::TPosition::VECCALC> tmpLocalBuf;

    // Values copied from VecTiling in Init(); zero-initialised so that an
    // un-initialised kernel object never carries indeterminate sizes.
    uint32_t tmpLocalBytes = 0;    // scratch size in bytes, rounded up to LOCAL_BYTES
    uint32_t nLength = 0;
    uint32_t rLength = 0;
    uint32_t abComputeLength = 0;
    float nRec = 0.0f;
    uint32_t bshLength = 0;        // nLength * rLength: element count of every buffer
    bool inplace = false;
};

} // namespace MyCustomKernel

#endif // EXAMPLES_NORMALIZATION_WELFORDUPDATE_CUSTOM_H
